【题解】P6633 [ZJOI2020] 抽卡

原题链接

90 分做法,所以这篇就不投题解了

f(i)f(i) 为抽了 ii 张卡还没有成功的方案数, mmkk 与题面意义相同,则有

ans=i=0mif(i)ans=\sum_{i=0}^\infty m^{-i}f(i)

尝试展开 f(i)f(i),考虑枚举抽了 ii 轮之后手里有的不同的卡的数量是 pp。同时设 F(i,p)F(i,p) 代表限定了 pp 张卡抽了 ii 轮每张卡都抽到的方案数,G(p)G(p) 代表选择了 pp 张不同的卡没有出现 kk 连号的方案数:

ans=i=0mif(i)=1+i=1mip=1mF(i,p)G(p)ans=\sum_{i=0}^\infty m^{-i}f(i)=1+\sum_{i=1}^\infty m^{-i}\sum_{p=1}^mF(i,p)G(p)

(这里把 i=0i=0 特判出来了,因为留着它不利于后面式子展开)

FF 是一个经典的球盒模型,它等于 ii 个不同的球放进 pp 个不同的盒子,且每个盒子都有球的方案数。很容易容斥求出:

F(i,p)=t=0p1(1)tC(p,t)(pt)iF(i,p)=\sum_{t=0}^{p-1}(-1)^tC(p,t)(p-t)^i

对原式改变求和顺序,把对 ii 的求和挪到里面,然后利用等比数列无穷项的求和公式:

i=1miF(i,p)=t=0p1(1)tC(p,t)ptmp+t\sum_{i=1}^\infty m^{-i}F(i,p)=\sum_{t=0}^{p-1}(-1)^tC(p,t)\frac{p-t}{m-p+t}

把组合数展开,发现这玩意是一个卷积的形式,还可以用 NTT 优化,以 O(mlogm)O(m\log m) 的复杂度计算出对于所有 p[1,m]p\in[1,m] 的值。我们先把他记作 H(p)H(p) 然后放在这。

ans=1+p=1mG(p)H(p)ans=1+\sum_{p=1}^mG(p)H(p)

接下来优化 GG。这东西可以分段计算,原卡组形成了若干个连续段,可以单独计算每个连续段。在一个长度为 nn 的段里选择选择 pp 张卡没有 kk 连号的方案数,隔板法转化一下其实也是球盒模型,它等于 pp 个一样的球丢进 np+1n-p+1 个不同的盒,且每个盒里的球小于 kk 个的方案数。照样容斥计算。

t=0pk(1)tC(np+1,t)C(ntk,np)\sum_{t=0}^{\lfloor\frac{p}{k}\rfloor}(-1)^tC(n-p+1,t)C(n-tk,n-p)

直接照着这个式子计算,总体时间复杂度是 O(m2k)O(\frac{m^2}{k}),能拿90分(省选要是能拿 90 分够香了吧)。满分做法就是把这个地方改成生成函数求,然而我并不会(

把每个连续段的答案合并其实就是背包合并,还是可以 NTT 优化。把两个大小为 xxyy 的背包合并需要花费 O((x+y)log)O((x+y)\log) 的时间,形成一个大小为 x+yx+y 的背包。这明显就是 P1090 合并果子 的模型,开个优先队列,每次合并两个最小的,可以证明这样做的总复杂度是 O(mlog2m)O(m\log^2m)。(倒数第 ii 次合并的两个背包长度和不会超过 2mi+1\frac{2m}{i+1},总长度和是 O(mlogm)O(m\log m)

全都合并起来之后剩下的那个背包就是 GG 的值,和刚才求出的 HH 做个点积加上 11 输出就做完了。

#include <cstdio>
#include <cstring>
#include <queue>
#include <algorithm>
#define M 200001
#define L (1<<19)
#define P 998244353
int m, k; bool mk[2*M];
int pow(int b, int p) {
	if (p < 0) p += P-1;
	long long a = b, ret = 1;
	while (p) {
		if (p & 1) ret = ret * a % P;
		a = a * a % P;
		p >>= 1;
	}
	return ret;
}
int ln2(int x) {return 32-__builtin_clz(x-1); }
int r[L], a[L], b[L];
void NTT(int* a, int f, int n) {
	for (int i = 0; i < n; i++)
		if (i < r[i]) std::swap(a[i], a[r[i]]);
	for (int l = 1; 2*l <= n; l <<= 1) {
		int w = pow(3, f*(P-1)/(2*l));
		for (int i = 0; i < n; i += 2*l) {
			long long d = 1;
			for (int k = 0; k < l; k++) {
				int u = a[i+k], v = a[i+l+k];
				a[i+k] = (d * v + u) % P;
				a[i+l+k] = ((P-d) * v + u) % P;
				d = d * w % P;
			}
		}			
	}
}
void mul(int ln) {
	#define N (1<<ln)
	for (int i = 1; i < N; i++)
		r[i] = (r[i>>1]>>1) | ((i&1)<<(ln-1));
	NTT(a, 1, N);
	NTT(b, 1, N);
	long long x = pow(1<<ln, -1);
	for (int i = 0; i < N; i++)
		a[i] = x * a[i] % P * b[i] % P;
	NTT(a, -1, N);
}
std::vector<int> G[M]; int sz = 0;
int fi[M], fv[M], H[M];
int C(int n, int m) {
	return 1ll * fi[n] * fv[m] % P * fv[n-m] % P;
}
int T1(int n, int p) {
	int rs = 0;
	for (int t = 0; t <= n-p+1 && t*k <= p; t++) {
		int xx = 1ll * C(n-p+1, t) * C(n-t*k, n-p) % P;
		if (t & 1) xx = P-xx;
		rs = (rs + xx) % P;
	}
	return rs;
}
void work(int x) {
	G[sz].push_back(1);
	for (int i = 1; i <= x; i++)
		G[sz].push_back(T1(x, i));
	while (!G[sz].back()) G[sz].pop_back();
	sz++;
}
int main() {
	scanf("%d%d", &m, &k);
	fi[0] = 1;
	for (int i = 1; i <= m; i++)
		fi[i] = 1ll * fi[i-1] * i % P;
	fv[m] = pow(fi[m], -1);
	for (int i = m; i >= 1; i--)
		fv[i-1] = 1ll * fv[i] * i % P;
	for (int i = 0, x; i < m; i++)
		{scanf("%d", &x); mk[x] = true; }
	for (int i = 2*m, x = 0; i >= 0; i--) {
		if (!mk[i]) {if (x) work(x); x = 0; }
		else x++;
	}
	for (int i = 0; i < m-1; i++) {
		a[i] = fv[i];
		if (i & 1) a[i] = P-a[i];
		b[i] = 1ll * fv[i] * pow(m-i-1, -1) % P;
	}
	int ln = ln2(2*m-3);
	mul(ln);
	for (int i = 1; i < m; i++)
		H[i] = 1ll * fi[i] * a[i-1] % P;
	auto cmp = [](int x, int y) {return G[x].size() > G[y].size(); };
	std::priority_queue<int, std::vector<int>, decltype(cmp)> pq(cmp);
	for (int i = 0; i < sz; i++) pq.push(i);
	while (pq.size() > 1) {
		int x = pq.top(); pq.pop();
		int y = pq.top(); pq.pop();
		int sx = G[x].size(), sy = G[y].size();
		int ln = ln2(sx+sy-1);
		memcpy(a, G[x].data(), sx * sizeof(int));
		memcpy(b, G[y].data(), sy * sizeof(int));
		memset(a+sx, 0, ((1<<ln)-sx) * sizeof(int));
		memset(b+sy, 0, ((1<<ln)-sy) * sizeof(int));
		mul(ln);
		G[x] = std::vector<int>(a, a+sx+sy-1);
		G[y].clear();
		pq.push(x);
	}
	int x = pq.top(), ans = 1;
	for (int i = 1; i < G[x].size(); i++)
		ans = (1ll * H[i] * G[x][i] + ans) % P;
	printf("%d\n", ans);
}